from knowledge_tracing.args import ARGS
import random
import torch
import numpy as np
from datasets.dataset_parser import Interaction, Constants


# Seed every random number generator used by this module (Python's `random`,
# NumPy, and PyTorch's CPU and CUDA generators) with the configured seed.
# This runs once, at import time.
random.seed(ARGS.random_seed)
np.random.seed(ARGS.random_seed)
torch.manual_seed(ARGS.random_seed)
torch.cuda.manual_seed(ARGS.random_seed)

# Dataset-level constants for the configured dataset. The functions below read
# NUM_ITEMS, NUM_TAGS, QID_TO_SIDS and SID_TO_QIDS from this object.
data_constants = Constants(ARGS.dataset_name, ARGS.data_root)


def generate_random_indices(seq_size, prob):
    """
    Draw a random, ascending list of distinct positions in range(seq_size).

    The number of positions is sampled from Binomial(seq_size, prob) with
    NumPy's global RNG; the positions themselves are then drawn without
    replacement with Python's `random` module.

    Args:
        seq_size: number of positions to choose from
        prob: success probability of the binomial draw

    Returns:
        sorted list of the chosen positions (possibly empty)
    """
    # One binomial draw decides how many positions get selected.
    (pick_count,) = np.random.binomial(seq_size, prob, 1)
    chosen = random.sample(range(seq_size), pick_count)
    chosen.sort()
    return chosen


def random_replacement(sequence, rep_prob):
    """
    Overwrite randomly chosen interactions with random content, in place.

    If int(rep_prob * len(sequence)) is 0 the sequence is returned untouched
    (this also covers an empty sequence). Otherwise the positions to
    overwrite are drawn with generate_random_indices(len(sequence), rep_prob).

    Which fields are overwritten depends on the attributes of sequence[0]:
      * with 'item_idx': item_idx becomes a random id in [1, NUM_ITEMS],
        is_correct becomes a random 0/1 and, if the interactions have
        'tags', tags becomes QID_TO_SIDS[new item id];
      * without 'item_idx': tags becomes a single random tag id in
        [1, NUM_TAGS] and is_correct becomes a random 0/1.

    NOTE(review): interaction_idx is not touched here, unlike in
    skill_based_replacement -- confirm that this is intended.

    Args:
        sequence: list of interactions; modified in place
        rep_prob: replacement probability

    Returns:
        the same `sequence` object
    """
    seq_len = len(sequence)
    # Decide "nothing to do" before looking at sequence[0], so an empty
    # sequence is returned as-is instead of raising IndexError.
    if int(rep_prob * seq_len) == 0:
        return sequence

    # The feature set is taken from the first interaction and assumed to be
    # the same for all of them.
    interaction_features = vars(sequence[0])
    has_item_idx = 'item_idx' in interaction_features
    has_tags = 'tags' in interaction_features

    replace_locations = generate_random_indices(seq_len, rep_prob)

    for i in replace_locations:
        interaction = sequence[i]
        if has_item_idx:
            new_qid = random.sample(range(1, data_constants.NUM_ITEMS + 1), 1)[0]
            interaction.item_idx = new_qid
            interaction.is_correct = random.sample([0, 1], 1)[0]
            if has_tags:
                # keep the tags consistent with the newly drawn question
                interaction.tags = data_constants.QID_TO_SIDS[new_qid]
        else:
            # tag-only interactions: draw a random tag id instead
            new_sid = random.sample(range(1, data_constants.NUM_TAGS + 1), 1)[0]
            interaction.tags = new_sid
            interaction.is_correct = random.sample([0, 1], 1)[0]
    return sequence


def get_random_q_id_from_old_q_id(q_id):
    """
    Pick a random question id that shares a skill with question `q_id`.

    One skill of `q_id` is drawn at random from QID_TO_SIDS[q_id], then one
    question is drawn at random from SID_TO_QIDS of that skill. If `q_id`
    has no skills, a question id is drawn uniformly from [1, NUM_ITEMS].

    Args:
        q_id: id of the original question

    Returns:
        the randomly chosen question id
    """
    skills_of_question = data_constants.QID_TO_SIDS[q_id]
    if len(skills_of_question) == 0:
        # no skill -> randomly sample a question id from all questions
        item_count = data_constants.NUM_ITEMS
        return random.sample(range(1, item_count + 1), 1)[0]

    picked_skill = random.sample(skills_of_question, 1)[0]
    questions_with_skill = data_constants.SID_TO_QIDS[picked_skill]
    return random.sample(questions_with_skill, 1)[0]


def skill_based_replacement(sequence, rep_prob, rep_type='skill', response='all'):
    """
    Replace the question of randomly chosen interactions, in place.

    Args:
        sequence: list of interactions, modified in place. Each interaction
            is expected to have item_idx, is_correct and interaction_idx
            (and tags for the '*-rand' types).
        rep_prob: replacement probability
        rep_type: how the new question is chosen
            'skill'     - random question sharing a skill with the old one
                          (see get_random_q_id_from_old_q_id)
            'dif-skill' - random question sharing no skill with the old one
            'q-rand'    - uniformly random question; tags are updated too
            'i-rand'    - like 'q-rand', and is_correct is re-drawn as well
        response: 'all' to consider every position, otherwise a value
            accepted by int() (e.g. '0' or '1'); then only interactions
            whose is_correct equals that value are candidates.

    Returns:
        (sequence, not_replace_locations): the same `sequence` object and the
        ascending list of positions that were not selected for replacement.

    Raises:
        NotImplementedError: if a replacement is due and `rep_type` is not
            one of the supported values.

    NOTE(review): for 'skill' and 'dif-skill' only item_idx (and
    interaction_idx) change; tags keep the old question's value -- confirm
    that this is intended.
    """
    seq_len = len(sequence)
    if response == 'all':
        replace_locations = generate_random_indices(seq_len, rep_prob)
    else:  # fixed response replacement
        target_response = int(response)
        candidate_locations = [
            i for i, interaction in enumerate(sequence)
            if interaction.is_correct == target_response
        ]
        dnum = np.random.binomial(len(candidate_locations), rep_prob, 1)[0]
        replace_locations = sorted(random.sample(candidate_locations, dnum))

    if not replace_locations:
        # nothing was selected, so every position is "not replaced"
        return sequence, list(range(seq_len))

    if rep_type not in ('skill', 'dif-skill', 'q-rand', 'i-rand'):
        # Fail loudly instead of hitting an unbound `new_qid` below or
        # silently reporting untouched positions as replaced.
        raise NotImplementedError('unsupported rep_type: {!r}'.format(rep_type))

    qnum = data_constants.NUM_ITEMS
    if rep_type == 'dif-skill':
        # loop-invariant: the set of all question ids
        all_qids = set(range(1, qnum + 1))

    for i in replace_locations:
        # assume that features include 'item_idx'
        interaction = sequence[i]
        if rep_type == 'skill':
            new_qid = get_random_q_id_from_old_q_id(interaction.item_idx)
            interaction.item_idx = new_qid
        elif rep_type == 'dif-skill':
            # every question that shares at least one skill with the old one
            related_qids = set()
            for sid in data_constants.QID_TO_SIDS[interaction.item_idx]:
                related_qids.update(data_constants.SID_TO_QIDS[sid])
            # sorted so the draw does not depend on set iteration order
            new_qid_cands = sorted(all_qids - related_qids)
            new_qid = random.sample(new_qid_cands, 1)[0]
            interaction.item_idx = new_qid
        else:  # 'q-rand' or 'i-rand'
            new_qid = random.sample(range(1, qnum + 1), 1)[0]
            interaction.item_idx = new_qid
            interaction.tags = data_constants.QID_TO_SIDS[new_qid]
            if rep_type == 'i-rand':
                interaction.is_correct = random.sample([0, 1], 1)[0]
        if interaction.interaction_idx is not None:
            # same encoding as used in random_insertion: 2 * qid - is_correct
            interaction.interaction_idx = new_qid * 2 - interaction.is_correct

    replaced = set(replace_locations)  # O(1) membership tests
    not_replace_locations = [i for i in range(seq_len) if i not in replaced]
    return sequence, not_replace_locations


def random_deletion(sequence, del_prob, response=None, return_idx=False):
    """
    Delete interactions randomly and report which positions were kept.

    Args:
        sequence: list of interactions; not modified, a new list is returned
        del_prob: random removing probability
        response: None to consider every position, or '0' / '1' to delete
            only among interactions whose is_correct equals that value
        return_idx: unused; kept for interface compatibility

    Returns:
        (new_sequence, not_delete_locations): the surviving interactions in
        their original order and their ascending positions in `sequence`.

    Raises:
        NotImplementedError: if `response` is not None, '0' or '1'.
    """
    seq_len = len(sequence)
    if response is None:  # random deletion, response doesn't matter
        delete_locations = generate_random_indices(seq_len, del_prob)
    elif response in ['0', '1']:  # fixed response deletion
        target_response = int(response)
        candidate_locations = [
            i for i, interaction in enumerate(sequence)
            if interaction.is_correct == target_response
        ]
        # number of interactions to be deleted
        dnum = int(np.random.binomial(len(candidate_locations), del_prob, 1)[0])
        # indices of interactions to be deleted
        delete_locations = random.sample(candidate_locations, dnum)
    else:
        raise NotImplementedError

    deleted = set(delete_locations)
    not_delete_locations = [i for i in range(seq_len) if i not in deleted]
    # Filter with plain list indexing rather than np.delete: converting the
    # list to an ndarray would coerce the elements (sequence-like
    # interactions become extra array dimensions and get flattened).
    kept_sequence = [sequence[i] for i in not_delete_locations]
    return kept_sequence, not_delete_locations


def random_insertion(sequence, ins_prob, response='gc', ins_type='random', return_idx=False):
    """
    Insert randomly generated interactions at random positions, in place.

    The number of insertions is drawn from Binomial(len(sequence), ins_prob)
    and capped so that the result is no longer than ARGS.seq_size.

    Args:
        sequence: list of interactions; modified in place (slice assignment)
        ins_prob: insertion probability
        response: 'rand' to draw is_correct of each inserted interaction
            at random, otherwise a value accepted by int() (e.g. '0' or '1')
            used for all of them.
            NOTE(review): the default 'gc' is not accepted by int(), so the
            default raises ValueError whenever an insertion is performed.
        ins_type: 'random' to draw question ids from [1, NUM_ITEMS], or
            'skill' to draw them from the questions sharing at least one
            skill with a question of `sequence` (falling back to all
            questions if there are none)
        return_idx: unused; kept for interface compatibility

    Returns:
        (sequence, not_insert_locations): the same `sequence` object and the
        ascending positions, in the new sequence, of the original
        interactions.
        NOTE(review): when nothing is inserted the second element is
        list(range(ARGS.seq_size)), not list(range(len(sequence))) -- confirm
        that callers expect this.

    Raises:
        NotImplementedError: if an insertion is due and the interactions
            have no 'item_idx' attribute (tag-only interactions are not
            supported), or `ins_type` is not 'random' or 'skill'.
        ValueError: if an insertion is due and `response` is neither 'rand'
            nor accepted by int(); also raised by random.sample when
            len(sequence) > ARGS.seq_size (negative insertion count).
    """
    seq_len = len(sequence)

    insert_num = np.random.binomial(seq_len, ins_prob)
    insert_num = min(insert_num, ARGS.seq_size - seq_len)
    # Bail out before looking at sequence[0] or building candidate pools, so
    # an empty sequence is returned unchanged.
    if insert_num == 0:
        return sequence, list(range(ARGS.seq_size))

    total_len = seq_len + insert_num
    insert_locations = sorted(random.sample(range(total_len), insert_num))

    if 'item_idx' not in vars(sequence[0]):
        # The original tag-only branch wrote to an interaction that was never
        # created; make the limitation explicit instead.
        raise NotImplementedError(
            'random_insertion requires interactions with an item_idx attribute'
        )

    qnum = data_constants.NUM_ITEMS
    if ins_type == 'random':
        qid_pool = range(1, qnum + 1)
    elif ins_type == 'skill':
        # the set of skills of questions in original sequence.
        skills = sorted({s for interaction in sequence for s in interaction.tags})
        # the set of questions that has common skill with
        # at least one of the questions in the original sequence
        qid_candidates = sorted(
            {q for s in skills for q in data_constants.SID_TO_QIDS[s]}
        )
        qid_pool = qid_candidates if qid_candidates else range(1, qnum + 1)
    else:
        raise NotImplementedError('unsupported ins_type: {!r}'.format(ins_type))

    # None means "draw is_correct per inserted interaction"
    fixed_is_correct = None if response == 'rand' else int(response)

    # generate random sequence of interactions
    insert_inters = []
    for _ in range(insert_num):
        new_qid = random.sample(qid_pool, 1)[0]
        if fixed_is_correct is None:
            new_is_correct = random.sample([0, 1], 1)[0]
        else:  # fixed response, '1' or '0'
            new_is_correct = fixed_is_correct
        insert_inters.append(
            Interaction(
                item_idx=new_qid,
                is_correct=new_is_correct,
                tags=data_constants.QID_TO_SIDS[new_qid],
                interaction_idx=2 * new_qid - new_is_correct,
            )
        )

    # Merge: inserted interactions go to their drawn positions, the original
    # ones fill the remaining positions in their original order.
    inserts = dict(zip(insert_locations, insert_inters))
    originals = iter(sequence)
    sequence[:] = [
        inserts[pos] if pos in inserts else next(originals)
        for pos in range(total_len)
    ]
    not_insert_locations = [pos for pos in range(total_len) if pos not in inserts]
    return sequence, not_insert_locations
